{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tTLuI6Xz4yvr"
      },
      "source": [
        "##### Copyright 2020 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "i5R7Gb120Ekx"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XPflk8Ey7e9D"
      },
      "source": [
        "# トレーニングループの新規作成"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0pA8D5OrHGIx"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td><a target=\"_blank\" href=\"https://www.tensorflow.org/guide/keras/writing_a_training_loop_from_scratch\">     <img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\">TensorFlow.org で表示</a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs-l10n/blob/master/site/ja/guide/keras/writing_a_training_loop_from_scratch.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\"> Google Colab で実行</a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://github.com/tensorflow/docs-l10n/blob/master/site/ja/guide/keras/writing_a_training_loop_from_scratch.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\">GitHub でソースを表示</a></td>\n",
        "  <td><a href=\"https://storage.googleapis.com/tensorflow_docs/docs-l10n/site/ja/guide/keras/writing_a_training_loop_from_scratch.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\">ノートブックをダウンロード</a></td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "l7oVQp2POvp8"
      },
      "source": [
        "## セットアップ"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jkWTZxIyysDm"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "from tensorflow import keras\n",
        "from tensorflow.keras import layers\n",
        "import numpy as np"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TcNr4V6Z2eDw"
      },
      "source": [
        "## はじめに\n",
        "\n",
        "Keras はデフォルトでトレーニングと評価ループに所要できる`fit()`と`evaluate()`を提供しています。これらの使い方については、[「組み込みのメソッドによるトレーニングと評価」](https://www.tensorflow.org/guide/keras/train_and_evaluate/)を参照してください。\n",
        "\n",
        "`fit()`の利便性を活用しながらモデルの学習アルゴリズムをカスタマイズするには（`fit()`を使用して GAN をトレーニングする場合など）、`Model`クラスをサブクラス化し、`fit()`中に繰り返し呼び出される独自の`train_step()`メソッドを実装します。詳細については、「<a data-md-type=\"link\" href=\"https://www.tensorflow.org/guide/keras/customizing_what_happens_in_fit/\">`fit()`で行われる処理のカスタマイズ</a>」を参照してください。\n",
        "\n",
        "トレーニングと評価を非常に低レベルで制御する場合は、独自のトレーニングと評価ループを新規作成する必要があります。このガイドでは、トレーニングと評価ループを新規作成する方法を見ていきます。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MvWJHarCmdOp"
      },
      "source": [
        "## `GradientTape`を使用するには: 最初のエンドツーエンドの例\n",
        "\n",
        "`GradientTape`スコープ内でモデルを呼び出すと、損失値に対するレイヤーのトレーニング可能な重みの勾配を取得できます。オプティマイザのインスタンスを使用すると、これらの勾配を使用してこれらの変数（`model.trainable_weights`を使用して取得できます）を更新できます。\n",
        "\n",
        "では、シンプルな MNIST モデルを見てみましょう。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Xd8GpFudzEiU"
      },
      "outputs": [],
      "source": [
        "inputs = keras.Input(shape=(784,), name=\"digits\")\n",
        "x1 = layers.Dense(64, activation=\"relu\")(inputs)\n",
        "x2 = layers.Dense(64, activation=\"relu\")(x1)\n",
        "outputs = layers.Dense(10, name=\"predictions\")(x2)\n",
        "model = keras.Model(inputs=inputs, outputs=outputs)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9uXwjRTjgQKE"
      },
      "source": [
        "ミニバッチの勾配を使用してカスタムトレーニングループでトレーニングします。\n",
        "\n",
        "まず、オプティマイザ、損失関数、およびデータセットが必要です。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xnxsNbPmMvqr"
      },
      "outputs": [],
      "source": [
        "# Instantiate an optimizer.\n",
        "optimizer = keras.optimizers.SGD(learning_rate=1e-3)\n",
        "# Instantiate a loss function.\n",
        "loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n",
        "\n",
        "# Prepare the training dataset.\n",
        "batch_size = 64\n",
        "(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()\n",
        "x_train = np.reshape(x_train, (-1, 784))\n",
        "x_test = np.reshape(x_train, (-1, 784))\n",
        "train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))\n",
        "train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VmZj17qo05qZ"
      },
      "source": [
        "トレーニングループは以下のとおりです。\n",
        "\n",
        "- エポックをイテレートする`for`ループを開きます\n",
        "- エポックごとに、バッチでデータセットをイテレートする`for`ループを開きます\n",
        "- バッチごとに、`GradientTape()`スコープを開きます\n",
        "- このスコープ内で、モデル（フォワードパス）を呼び出し、損失を計算します\n",
        "- スコープ外で、損失に対するモデルの重みの勾配を取得します\n",
        "- 最後に、オプティマイザを使用して、勾配に基づいてモデルの重みを更新します"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KP7wZij8LJyh"
      },
      "outputs": [],
      "source": [
        "epochs = 2\n",
        "for epoch in range(epochs):\n",
        "    print(\"\\nStart of epoch %d\" % (epoch,))\n",
        "\n",
        "    # Iterate over the batches of the dataset.\n",
        "    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):\n",
        "\n",
        "        # Open a GradientTape to record the operations run\n",
        "        # during the forward pass, which enables autodifferentiation.\n",
        "        with tf.GradientTape() as tape:\n",
        "\n",
        "            # Run the forward pass of the layer.\n",
        "            # The operations that the layer applies\n",
        "            # to its inputs are going to be recorded\n",
        "            # on the GradientTape.\n",
        "            logits = model(x_batch_train, training=True)  # Logits for this minibatch\n",
        "\n",
        "            # Compute the loss value for this minibatch.\n",
        "            loss_value = loss_fn(y_batch_train, logits)\n",
        "\n",
        "        # Use the gradient tape to automatically retrieve\n",
        "        # the gradients of the trainable variables with respect to the loss.\n",
        "        grads = tape.gradient(loss_value, model.trainable_weights)\n",
        "\n",
        "        # Run one step of gradient descent by updating\n",
        "        # the value of the variables to minimize the loss.\n",
        "        optimizer.apply_gradients(zip(grads, model.trainable_weights))\n",
        "\n",
        "        # Log every 200 batches.\n",
        "        if step % 200 == 0:\n",
        "            print(\n",
        "                \"Training loss (for one batch) at step %d: %.4f\"\n",
        "                % (step, float(loss_value))\n",
        "            )\n",
        "            print(\"Seen so far: %s samples\" % ((step + 1) * 64))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vH60dyqVSDTX"
      },
      "source": [
        "## メトリックの低レベルの処理\n",
        "\n",
        "この基本ループにメトリックの監視を追加しましょう。\n",
        "\n",
        "新規作成されたトレーニングループでは、組み込みメトリック（またはユーザーが作成したカスタムメトリック）をすぐに再利用できます。フローは以下のとおりです。\n",
        "\n",
        "- ループ開始時にメトリックをインスタンス化します\n",
        "- 各バッチの後に`metric.update_state()`を呼び出します\n",
        "- メトリックのその時点の値を表示する必要がある場合は、`metric.result()`を呼び出します\n",
        "- メトリックの状態をクリアする必要がある場合（通常はエポックの終わり）に`metric.reset_states()`を呼び出します\n",
        "\n",
        "上記を使用して、各エポックの終わりの検証データで`SparseCategoricalAccuracy`を計算してみましょう。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JUriDlmmzLey"
      },
      "outputs": [],
      "source": [
        "# Get model\n",
        "inputs = keras.Input(shape=(784,), name=\"digits\")\n",
        "x = layers.Dense(64, activation=\"relu\", name=\"dense_1\")(inputs)\n",
        "x = layers.Dense(64, activation=\"relu\", name=\"dense_2\")(x)\n",
        "outputs = layers.Dense(10, name=\"predictions\")(x)\n",
        "model = keras.Model(inputs=inputs, outputs=outputs)\n",
        "\n",
        "# Instantiate an optimizer to train the model.\n",
        "optimizer = keras.optimizers.SGD(learning_rate=1e-3)\n",
        "# Instantiate a loss function.\n",
        "loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n",
        "\n",
        "# Prepare the metrics.\n",
        "train_acc_metric = keras.metrics.SparseCategoricalAccuracy()\n",
        "val_acc_metric = keras.metrics.SparseCategoricalAccuracy()\n",
        "\n",
        "# Prepare the training dataset.\n",
        "batch_size = 64\n",
        "train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))\n",
        "train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)\n",
        "\n",
        "# Prepare the validation dataset.\n",
        "# Reserve 10,000 samples for validation.\n",
        "x_val = x_train[-10000:]\n",
        "y_val = y_train[-10000:]\n",
        "x_train = x_train[:-10000]\n",
        "y_train = y_train[:-10000]\n",
        "val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))\n",
        "val_dataset = val_dataset.batch(64)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IpAbe0hltQSv"
      },
      "source": [
        "トレーニングと評価のループは以下のとおりです。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XYep4rMxu9D1"
      },
      "outputs": [],
      "source": [
        "import time\n",
        "\n",
        "epochs = 2\n",
        "for epoch in range(epochs):\n",
        "    print(\"\\nStart of epoch %d\" % (epoch,))\n",
        "    start_time = time.time()\n",
        "\n",
        "    # Iterate over the batches of the dataset.\n",
        "    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):\n",
        "        with tf.GradientTape() as tape:\n",
        "            logits = model(x_batch_train, training=True)\n",
        "            loss_value = loss_fn(y_batch_train, logits)\n",
        "        grads = tape.gradient(loss_value, model.trainable_weights)\n",
        "        optimizer.apply_gradients(zip(grads, model.trainable_weights))\n",
        "\n",
        "        # Update training metric.\n",
        "        train_acc_metric.update_state(y_batch_train, logits)\n",
        "\n",
        "        # Log every 200 batches.\n",
        "        if step % 200 == 0:\n",
        "            print(\n",
        "                \"Training loss (for one batch) at step %d: %.4f\"\n",
        "                % (step, float(loss_value))\n",
        "            )\n",
        "            print(\"Seen so far: %d samples\" % ((step + 1) * 64))\n",
        "\n",
        "    # Display metrics at the end of each epoch.\n",
        "    train_acc = train_acc_metric.result()\n",
        "    print(\"Training acc over epoch: %.4f\" % (float(train_acc),))\n",
        "\n",
        "    # Reset training metrics at the end of each epoch\n",
        "    train_acc_metric.reset_states()\n",
        "\n",
        "    # Run a validation loop at the end of each epoch.\n",
        "    for x_batch_val, y_batch_val in val_dataset:\n",
        "        val_logits = model(x_batch_val, training=False)\n",
        "        # Update val metrics\n",
        "        val_acc_metric.update_state(y_batch_val, val_logits)\n",
        "    val_acc = val_acc_metric.result()\n",
        "    val_acc_metric.reset_states()\n",
        "    print(\"Validation acc: %.4f\" % (float(val_acc),))\n",
        "    print(\"Time taken: %.2fs\" % (time.time() - start_time))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uIzUVMFmKOL6"
      },
      "source": [
        "## `tf.function`でトレーニングステップをスピードアップ\n",
        "\n",
        "TensorFlow 2.0 のデフォルトのランタイムは [Eager execution](https://www.tensorflow.org/guide/eager) です。そのため、上記のトレーニングループが迅速に実行されます。\n",
        "\n",
        "これはデバッグに最適ですが、グラフのコンパイルにはパフォーマンス上の明確な利点があります。計算を静的グラフとして記述すると、フレームワークでグローバルパフォーマンス最適化を適用できます。フレームワークが次の演算を知らずに、次から次へと演算を迅速に実行するように制約されている場合は不可能です。\n",
        "\n",
        "テンソルを入力として受け取る関数は、以下のように`@tf.function`デコレータを追加するだけで静的グラフにコンパイルできます。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rdU2jDOB4itB"
      },
      "outputs": [],
      "source": [
        "@tf.function\n",
        "def train_step(x, y):\n",
        "    with tf.GradientTape() as tape:\n",
        "        logits = model(x, training=True)\n",
        "        loss_value = loss_fn(y, logits)\n",
        "    grads = tape.gradient(loss_value, model.trainable_weights)\n",
        "    optimizer.apply_gradients(zip(grads, model.trainable_weights))\n",
        "    train_acc_metric.update_state(y, logits)\n",
        "    return loss_value\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SpPlBEHjXg9c"
      },
      "source": [
        "評価ステップでも同じように実行できます。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2yGMv9LPrvYl"
      },
      "outputs": [],
      "source": [
        "@tf.function\n",
        "def test_step(x, y):\n",
        "    val_logits = model(x, training=False)\n",
        "    val_acc_metric.update_state(y, val_logits)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Y7Elq9Kky2HQ"
      },
      "source": [
        "次に、このコンパイルされたトレーニングステップでトレーニングループを再度実行します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jb12E4zoBDOy"
      },
      "outputs": [],
      "source": [
        "import time\n",
        "\n",
        "epochs = 2\n",
        "for epoch in range(epochs):\n",
        "    print(\"\\nStart of epoch %d\" % (epoch,))\n",
        "    start_time = time.time()\n",
        "\n",
        "    # Iterate over the batches of the dataset.\n",
        "    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):\n",
        "        loss_value = train_step(x_batch_train, y_batch_train)\n",
        "\n",
        "        # Log every 200 batches.\n",
        "        if step % 200 == 0:\n",
        "            print(\n",
        "                \"Training loss (for one batch) at step %d: %.4f\"\n",
        "                % (step, float(loss_value))\n",
        "            )\n",
        "            print(\"Seen so far: %d samples\" % ((step + 1) * 64))\n",
        "\n",
        "    # Display metrics at the end of each epoch.\n",
        "    train_acc = train_acc_metric.result()\n",
        "    print(\"Training acc over epoch: %.4f\" % (float(train_acc),))\n",
        "\n",
        "    # Reset training metrics at the end of each epoch\n",
        "    train_acc_metric.reset_states()\n",
        "\n",
        "    # Run a validation loop at the end of each epoch.\n",
        "    for x_batch_val, y_batch_val in val_dataset:\n",
        "        test_step(x_batch_val, y_batch_val)\n",
        "\n",
        "    val_acc = val_acc_metric.result()\n",
        "    val_acc_metric.reset_states()\n",
        "    print(\"Validation acc: %.4f\" % (float(val_acc),))\n",
        "    print(\"Time taken: %.2fs\" % (time.time() - start_time))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kG0YiXFj519D"
      },
      "source": [
        "スピードアップしました。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LtECMpVqBAl8"
      },
      "source": [
        "## モデルにより追跡される損失の低レベルの処理\n",
        "\n",
        "レイヤーとモデルは、`self.add_loss(value)`を呼び出すレイヤーによるフォワードパス中に発生した損失を再帰的に追跡します。結果のスカラー損失値のリストは、フォワードパスの最後にあるプロパティ`model.losses`に表示されます。\n",
        "\n",
        "これらの損失を使用する場合は、合計して、トレーニングステップの主な損失に追加する必要があります。\n",
        "\n",
        "次のレイヤーはアクティビティの正規化の損失を作成します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EfiB6akb2kBT"
      },
      "outputs": [],
      "source": [
        "class ActivityRegularizationLayer(layers.Layer):\n",
        "    def call(self, inputs):\n",
        "        self.add_loss(1e-2 * tf.reduce_sum(inputs))\n",
        "        return inputs\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FuaN95bX7txO"
      },
      "source": [
        "これを使用する非常にシンプルなモデルを構築しましょう。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "u8JIjdc6sWRo"
      },
      "outputs": [],
      "source": [
        "inputs = keras.Input(shape=(784,), name=\"digits\")\n",
        "x = layers.Dense(64, activation=\"relu\")(inputs)\n",
        "# Insert activity regularization as a layer\n",
        "x = ActivityRegularizationLayer()(x)\n",
        "x = layers.Dense(64, activation=\"relu\")(x)\n",
        "outputs = layers.Dense(10, name=\"predictions\")(x)\n",
        "\n",
        "model = keras.Model(inputs=inputs, outputs=outputs)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xWEGiFOAg2Pw"
      },
      "source": [
        "トレーニングステップは次のようになります。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7dDTCiM3ikkt"
      },
      "outputs": [],
      "source": [
        "@tf.function\n",
        "def train_step(x, y):\n",
        "    with tf.GradientTape() as tape:\n",
        "        logits = model(x, training=True)\n",
        "        loss_value = loss_fn(y, logits)\n",
        "        # Add any extra losses created during the forward pass.\n",
        "        loss_value += sum(model.losses)\n",
        "    grads = tape.gradient(loss_value, model.trainable_weights)\n",
        "    optimizer.apply_gradients(zip(grads, model.trainable_weights))\n",
        "    train_acc_metric.update_state(y, logits)\n",
        "    return loss_value\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ht7Utep1WEhg"
      },
      "source": [
        "## まとめ\n",
        "\n",
        "ここでは、組み込みのトレーニングループを使用して、独自のループを新規作成する方法を見てきました。\n",
        "\n",
        "最後に、このガイドで学習したすべてをまとめる簡単なエンドツーエンドの例を見てみましょう。MNIST の数字データでトレーニングされた DCGAN です。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "g1HCZDicyRMP"
      },
      "source": [
        "## エンドツーエンドの例：GAN トレーニングループの新規作成\n",
        "\n",
        "ご存知の方もいらっしゃると思いますが、敵対的生成ネットワーク（GAN）は、画像のトレーニングデータセット（画像の「潜在空間」）の潜在的な分布を学習することにより、ほぼ現実的な新しい画像を生成できます。\n",
        "\n",
        "GAN は「ジェネレータ」モデルと「ディスクリミネータ」モデルの 2 つで構成されています。「ジェネレータ」モデルは潜在空間の点を画像空間の点にマッピングし、「ディスクリミネータ」モデルは、実物の画像（トレーニングデータセットから）と偽の画像（ジェネレータネットワークの出力）の違いを判別する分類器です。\n",
        "\n",
        "GANトレーニングループは以下のとおりです。\n",
        "\n",
        "1. ディスクリミネータをトレーニングする。\n",
        "\n",
        "- 潜在空間のランダムポイントのバッチをサンプリングする。\n",
        "- 「ジェネレータ」モデルを使用して、ポイントを偽の画像に変換する。\n",
        "- 実際の画像のバッチを取得し、生成された画像と組み合わせる。\n",
        "- 「ディスクリミネータ」モデルをトレーニングして、生成された画像と実際の画像を分類する。\n",
        "\n",
        "1. ジェネレータをトレーニングする。\n",
        "\n",
        "- 潜在空間のランダムポイントをサンプリングする。\n",
        "- 「ジェネレータ」ネットワークを介して、ポイントを偽の画像に変換する。\n",
        "- 実際の画像のバッチを取得し、生成された画像と組み合わせる。\n",
        "- 「ジェネレータ」モデルをトレーニングして、ディスクリミネータを「だまし」、偽の画像を実物として分類する。\n",
        "\n",
        "GAN の詳細については、[「Deep Learning with Python」](https://www.manning.com/books/deep-learning-with-python)を参照してください。\n",
        "\n",
        "このトレーニングループを実装しましょう。まず、偽の数字と実際の数字を分類するためのディスクリミネータを作成します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "oo0K12W14lTB"
      },
      "outputs": [],
      "source": [
        "discriminator = keras.Sequential(\n",
        "    [\n",
        "        keras.Input(shape=(28, 28, 1)),\n",
        "        layers.Conv2D(64, (3, 3), strides=(2, 2), padding=\"same\"),\n",
        "        layers.LeakyReLU(alpha=0.2),\n",
        "        layers.Conv2D(128, (3, 3), strides=(2, 2), padding=\"same\"),\n",
        "        layers.LeakyReLU(alpha=0.2),\n",
        "        layers.GlobalMaxPooling2D(),\n",
        "        layers.Dense(1),\n",
        "    ],\n",
        "    name=\"discriminator\",\n",
        ")\n",
        "discriminator.summary()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CtDX25Xojcan"
      },
      "source": [
        "次に、潜在的なベクトルを形状`(28, 28, 1)`（MNISTの数字を表す）の出力に変換するジェネレータネットワークを作成します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "y01N2V7Ykamk"
      },
      "outputs": [],
      "source": [
        "latent_dim = 128\n",
        "\n",
        "generator = keras.Sequential(\n",
        "    [\n",
        "        keras.Input(shape=(latent_dim,)),\n",
        "        # We want to generate 128 coefficients to reshape into a 7x7x128 map\n",
        "        layers.Dense(7 * 7 * 128),\n",
        "        layers.LeakyReLU(alpha=0.2),\n",
        "        layers.Reshape((7, 7, 128)),\n",
        "        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding=\"same\"),\n",
        "        layers.LeakyReLU(alpha=0.2),\n",
        "        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding=\"same\"),\n",
        "        layers.LeakyReLU(alpha=0.2),\n",
        "        layers.Conv2D(1, (7, 7), padding=\"same\", activation=\"sigmoid\"),\n",
        "    ],\n",
        "    name=\"generator\",\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1D0KbjVFlyye"
      },
      "source": [
        "ここで重要なのが、トレーニングループです。ご覧のとおり、非常に簡単です。トレーニングステップの関数は 17 行だけです。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MGbCyv5MNPme"
      },
      "outputs": [],
      "source": [
        "# Instantiate one optimizer for the discriminator and another for the generator.\n",
        "d_optimizer = keras.optimizers.Adam(learning_rate=0.0003)\n",
        "g_optimizer = keras.optimizers.Adam(learning_rate=0.0004)\n",
        "\n",
        "# Instantiate a loss function.\n",
        "loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)\n",
        "\n",
        "\n",
        "@tf.function\n",
        "def train_step(real_images):\n",
        "    # Sample random points in the latent space\n",
        "    random_latent_vectors = tf.random.normal(shape=(batch_size, latent_dim))\n",
        "    # Decode them to fake images\n",
        "    generated_images = generator(random_latent_vectors)\n",
        "    # Combine them with real images\n",
        "    combined_images = tf.concat([generated_images, real_images], axis=0)\n",
        "\n",
        "    # Assemble labels discriminating real from fake images\n",
        "    labels = tf.concat(\n",
        "        [tf.ones((batch_size, 1)), tf.zeros((real_images.shape[0], 1))], axis=0\n",
        "    )\n",
        "    # Add random noise to the labels - important trick!\n",
        "    labels += 0.05 * tf.random.uniform(labels.shape)\n",
        "\n",
        "    # Train the discriminator\n",
        "    with tf.GradientTape() as tape:\n",
        "        predictions = discriminator(combined_images)\n",
        "        d_loss = loss_fn(labels, predictions)\n",
        "    grads = tape.gradient(d_loss, discriminator.trainable_weights)\n",
        "    d_optimizer.apply_gradients(zip(grads, discriminator.trainable_weights))\n",
        "\n",
        "    # Sample random points in the latent space\n",
        "    random_latent_vectors = tf.random.normal(shape=(batch_size, latent_dim))\n",
        "    # Assemble labels that say \"all real images\"\n",
        "    misleading_labels = tf.zeros((batch_size, 1))\n",
        "\n",
        "    # Train the generator (note that we should *not* update the weights\n",
        "    # of the discriminator)!\n",
        "    with tf.GradientTape() as tape:\n",
        "        predictions = discriminator(generator(random_latent_vectors))\n",
        "        g_loss = loss_fn(misleading_labels, predictions)\n",
        "    grads = tape.gradient(g_loss, generator.trainable_weights)\n",
        "    g_optimizer.apply_gradients(zip(grads, generator.trainable_weights))\n",
        "    return d_loss, g_loss, generated_images\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UwFhElNElUpn"
      },
      "source": [
        "画像のバッチに対して繰り返し`train_step`を呼び出して、GANをトレーニングします。\n",
        "\n",
        "ジェネレータとディスクリミネータは畳み込みニューラルネットワークなので、このコードを GPU で実行する必要があります。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TfCAGhKdDEVa"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "\n",
        "# Prepare the dataset. We use both the training &amp; test MNIST digits.\n",
        "batch_size = 64\n",
        "(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()\n",
        "all_digits = np.concatenate([x_train, x_test])\n",
        "all_digits = all_digits.astype(\"float32\") / 255.0\n",
        "all_digits = np.reshape(all_digits, (-1, 28, 28, 1))\n",
        "dataset = tf.data.Dataset.from_tensor_slices(all_digits)\n",
        "dataset = dataset.shuffle(buffer_size=1024).batch(batch_size)\n",
        "\n",
        "epochs = 1  # In practice you need at least 20 epochs to generate nice digits.\n",
        "save_dir = \"./\"\n",
        "\n",
        "for epoch in range(epochs):\n",
        "    print(\"\\nStart epoch\", epoch)\n",
        "\n",
        "    for step, real_images in enumerate(dataset):\n",
        "        # Train the discriminator &amp; generator on one batch of real images.\n",
        "        d_loss, g_loss, generated_images = train_step(real_images)\n",
        "\n",
        "        # Logging.\n",
        "        if step % 200 == 0:\n",
        "            # Print metrics\n",
        "            print(\"discriminator loss at step %d: %.2f\" % (step, d_loss))\n",
        "            print(\"adversarial loss at step %d: %.2f\" % (step, g_loss))\n",
        "\n",
        "            # Save one generated image\n",
        "            img = tf.keras.preprocessing.image.array_to_img(\n",
        "                generated_images[0] * 255.0, scale=False\n",
        "            )\n",
        "            img.save(os.path.join(save_dir, \"generated_img\" + str(step) + \".png\"))\n",
        "\n",
        "        # To limit execution time we stop after 10 steps.\n",
        "        # Remove the lines below to actually train the model!\n",
        "        if step &gt; 10:\n",
        "            break"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RF070l5Ik9xo"
      },
      "source": [
        "以上です。Colab GPU で約 30 秒のトレーニングを行うと、実物のような偽の MNIST の数字データが表示されます。"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "writing_a_training_loop_from_scratch.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
